{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Convert the original PairRM checkpoint to Hugging Face format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/dongfu/miniconda3/envs/llm-blender/lib/python3.9/site-packages/transformers/convert_slow_tokenizer.py:473: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Embedding(128005, 1024)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from llm_blender.pair_ranker.pairrm import DebertaV2PairRM\n",
    "from transformers import DebertaV2Config, AutoTokenizer\n",
    "# Start from the DeBERTa-v3-large config/tokenizer and register the PairRM marker tokens.\n",
    "config = DebertaV2Config.from_pretrained('microsoft/deberta-v3-large')\n",
    "tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-large')\n",
    "source_prefix = \"<|source|>\"\n",
    "cand1_prefix = \"<|candidate1|>\"\n",
    "cand2_prefix = \"<|candidate2|>\"\n",
    "cand_prefix = \"<|candidate|>\"\n",
    "tokenizer.add_tokens([source_prefix, cand1_prefix, cand2_prefix, cand_prefix])\n",
    "\n",
    "config.n_tasks = 1\n",
    "config.source_prefix_id = 128001\n",
    "config.cand1_prefix_id = 128002\n",
    "config.cand2_prefix_id = 128003\n",
    "config.cand_prefix_id = 128004\n",
    "config.drop_out = 0.05\n",
    "\n",
    "# The ids above are hard-coded, so fail loudly if the tokenizer assigned different ids\n",
    "# to the marker tokens; otherwise the exported config and tokenizer would silently disagree.\n",
    "expected_prefix_ids = [\n",
    "    config.source_prefix_id,\n",
    "    config.cand1_prefix_id,\n",
    "    config.cand2_prefix_id,\n",
    "    config.cand_prefix_id,\n",
    "]\n",
    "actual_prefix_ids = tokenizer.convert_tokens_to_ids([source_prefix, cand1_prefix, cand2_prefix, cand_prefix])\n",
    "assert actual_prefix_ids == expected_prefix_ids, (\n",
    "    f\"Tokenizer assigned prefix ids {actual_prefix_ids}, but the config expects {expected_prefix_ids}\"\n",
    ")\n",
    "\n",
    "pairrm = DebertaV2PairRM(config)\n",
    "# Grow the embedding matrix to cover the added tokens; the returned Embedding is the cell output.\n",
    "pairrm.pretrained_model.resize_token_embeddings(len(tokenizer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully loaded checkpoint from './PairRM/model.safetensors'\n"
     ]
    }
   ],
   "source": [
    "!git clone https://huggingface.co/llm-blender/PairRM\n",
    "# `import safetensors` alone does not load the `safetensors.torch` submodule; import it\n",
    "# explicitly so `safetensors.torch.load_model` resolves regardless of what else was imported.\n",
    "import safetensors.torch\n",
    "import logging\n",
    "\n",
    "checkpoint_path = \"./PairRM/model.safetensors\" # path of original pairrm model\n",
    "# load_model returns (missing_keys, unexpected_keys) for the weights copied into `pairrm`.\n",
    "load_result = safetensors.torch.load_model(pairrm, checkpoint_path)\n",
    "missing_keys, unexpected_keys = load_result\n",
    "if missing_keys:\n",
    "    print(f\"Missing keys: {missing_keys}\")\n",
    "if unexpected_keys:\n",
    "    print(f\"Unexpected keys: {unexpected_keys}\")\n",
    "if not missing_keys and not unexpected_keys:\n",
    "    print(f\"Successfully loaded checkpoint from '{checkpoint_path}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer, TrainingArguments\n",
    "# Trainer is used purely as an exporter here: nothing is trained, only `save_model` is called.\n",
    "export_args = TrainingArguments(\n",
    "    output_dir=\"./hf_PairRM\",\n",
    "    overwrite_output_dir=True,\n",
    ")\n",
    "trainer = Trainer(model=pairrm, args=export_args, tokenizer=tokenizer)\n",
    "trainer.save_model(\"./hf_PairRM/final_checkpoint\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Verifying Correctness"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the HF-format PairRM using `from_pretrained`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.9003021717071533, -1.2547134160995483]\n",
      "tensor([ True, False], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "from llm_blender.pair_ranker.pairrm import DebertaV2PairRM\n",
    "from transformers import AutoTokenizer\n",
    "from typing import List\n",
    "# Marker tokens that tokenize_pair prepends to the source and to each candidate.\n",
    "source_prefix, cand1_prefix, cand2_prefix = \"<|source|>\", \"<|candidate1|>\", \"<|candidate2|>\"\n",
    "\n",
    "# Load the converted checkpoint and its tokenizer from the Hub; eval() switches off dropout.\n",
    "tokenizer = AutoTokenizer.from_pretrained('llm-blender/PairRM-hf')\n",
    "pairrm = DebertaV2PairRM.from_pretrained(\"llm-blender/PairRM-hf\", device_map=\"cuda:0\")\n",
    "pairrm = pairrm.eval()\n",
    "\n",
    "# Toy batch: one source and two competing candidates per example.\n",
    "inputs = [\"hello!\", \"I love you!\"]\n",
    "candidates_A = [\"hi!\", \"I hate you!\"]\n",
    "candidates_B = [\"f**k off!\", \"I love you, too!\"]\n",
    "def tokenize_pair(sources:List[str], candidate1s:List[str], candidate2s:List[str], source_max_length=1224, candidate_max_length=412):\n",
    "    \"\"\"Encode (source, candidate1, candidate2) triples into one padded batch.\n",
    "\n",
    "    Each text is prefixed with its marker token (`source_prefix`, `cand1_prefix`,\n",
    "    `cand2_prefix`) and encoded separately; the three id lists are then concatenated.\n",
    "    The source is truncated to `source_max_length` tokens, and whatever remains of the\n",
    "    total budget (`source_max_length + 2 * candidate_max_length`) is split evenly\n",
    "    between the two candidates, so a short source leaves more room for the candidates.\n",
    "\n",
    "    Returns:\n",
    "        The result of `tokenizer.pad(...)` as PyTorch tensors, every row padded to the total budget.\n",
    "    \"\"\"\n",
    "    batch_ids = []\n",
    "    assert len(sources) == len(candidate1s) == len(candidate2s)\n",
    "    total_budget = source_max_length + 2 * candidate_max_length\n",
    "    for src_text, cand1_text, cand2_text in zip(sources, candidate1s, candidate2s):\n",
    "        src_ids = tokenizer.encode(source_prefix + src_text, max_length=source_max_length, truncation=True)\n",
    "        # Tokens not used by the source are shared equally by the two candidates.\n",
    "        per_candidate_budget = (total_budget - len(src_ids)) // 2\n",
    "        cand1_ids = tokenizer.encode(cand1_prefix + cand1_text, max_length=per_candidate_budget, truncation=True)\n",
    "        cand2_ids = tokenizer.encode(cand2_prefix + cand2_text, max_length=per_candidate_budget, truncation=True)\n",
    "        batch_ids.append(src_ids + cand1_ids + cand2_ids)\n",
    "    return tokenizer.pad({\"input_ids\": batch_ids}, return_tensors=\"pt\", padding=\"max_length\", max_length=total_budget)\n",
    "\n",
    "import torch\n",
    "\n",
    "encodings = tokenize_pair(inputs, candidates_A, candidates_B)\n",
    "encodings = {k:v.to(pairrm.device) for k,v in encodings.items()}\n",
    "# Inference only: disable autograd so no graph is kept for the padded sequences.\n",
    "with torch.no_grad():\n",
    "    outputs = pairrm(**encodings)\n",
    "logits = outputs.logits.tolist()\n",
    "comparison_results = outputs.logits > 0\n",
    "print(logits)\n",
    "# [1.9003021717071533, -1.2547134160995483]\n",
    "print(comparison_results)\n",
    "# tensor([ True, False], device='cuda:0'), which means whether candidate A is better than candidate B for each input\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load PairRM through the llm-blender wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:No ranker config provided, no ranker loaded, please load ranker first through load_ranker()\n",
      "WARNING:root:No fuser config provided, no fuser loaded, please load fuser first through load_fuser()\n",
      "/home/dongfu/miniconda3/envs/llm-blender/lib/python3.9/site-packages/dataclasses_json/core.py:187: RuntimeWarning: 'NoneType' object value of non-optional type load_checkpoint detected when decoding RankerConfig.\n",
      "  warnings.warn(\n",
      "/home/dongfu/miniconda3/envs/llm-blender/lib/python3.9/site-packages/transformers/convert_slow_tokenizer.py:473: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully loaded ranker from  /home/dongfu/data/.cache/huggingface/hub/llm-blender/PairRM\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Ranking candidates: 100%|██████████| 1/1 [00:00<00:00, 18.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 1.9   -1.255]\n",
      "[ True False]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "import llm_blender\n",
    "# Create the wrapper without ranker/fuser configs, then attach the PairRM ranker.\n",
    "blender = llm_blender.Blender()\n",
    "blender.loadranker(\"llm-blender/PairRM\") # load ranker checkpoint\n",
    "\n",
    "# Same toy batch as in the `from_pretrained` check, so the two outputs can be compared.\n",
    "inputs = [\"hello!\", \"I love you!\"]\n",
    "candidates_A = [\"hi!\", \"I hate you!\"]\n",
    "candidates_B = [\"f**k off!\", \"I love you, too!\"]\n",
    "logits = blender.compare(\n",
    "    inputs,\n",
    "    candidates_A,\n",
    "    candidates_B,\n",
    "    mode=\"[A,B]\",\n",
    "    return_logits=True,\n",
    ")\n",
    "comparison_results = logits > 0\n",
    "print(logits)\n",
    "# [ 1.9   -1.255]\n",
    "print(comparison_results)\n",
    "# [ True False], which means whether candidate A is better than candidate B for each input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_conv_pair(convAs: List[List[dict]], convBs: List[List[dict]]):\n",
    "    \"\"\"Tokenize pairs of conversations for pairwise comparison.\n",
    "\n",
    "    USER turns are taken as the input and ASSISTANT turns as the candidates.\n",
    "    Multi-turn conversations are supported.\n",
    "    A conversation is a list of turns in the format:\n",
    "    ```python\n",
    "    [\n",
    "        {\n",
    "            \"content\": \"hello\",\n",
    "            \"role\": \"USER\"\n",
    "        },\n",
    "        {\n",
    "            \"content\": \"hi\",\n",
    "            \"role\": \"ASSISTANT\"\n",
    "        },\n",
    "        ...\n",
    "    ]\n",
    "    ```\n",
    "    Args:\n",
    "        convAs (List[List[dict]]): First conversation of each pair.\n",
    "        convBs (List[List[dict]]): Second conversation of each pair; it must have the\n",
    "            same number of turns and the same USER turns as its counterpart in `convAs`.\n",
    "    Returns:\n",
    "        The encodings produced by `tokenize_pair` for the flattened conversations.\n",
    "    Raises:\n",
    "        AssertionError: If a conversation does not alternate USER/ASSISTANT turns, or if\n",
    "            the two sides of a pair differ in length or in USER content.\n",
    "    \"\"\"\n",
    "\n",
    "    # check that every conversation alternates USER / ASSISTANT turns\n",
    "    for c in convAs + convBs:\n",
    "        assert len(c) % 2 == 0, \"Each conversation must have even number of turns\"\n",
    "        assert all(c[i]['role'] == 'USER' for i in range(0, len(c), 2)), \"Each even turn must be USER\"\n",
    "        assert all(c[i]['role'] == 'ASSISTANT' for i in range(1, len(c), 2)), \"Each odd turn must be ASSISTANT\"\n",
    "    # check that the two sides of every pair are comparable\n",
    "    assert len(convAs) == len(convBs), \"Number of conversations must be the same\"\n",
    "    for c_a, c_b in zip(convAs, convBs):\n",
    "        assert len(c_a) == len(c_b), \"Number of turns in each conversation must be the same\"\n",
    "        assert all(c_a[i]['content'] == c_b[i]['content'] for i in range(0, len(c_a), 2)), \"USER turns must be the same\"\n",
    "\n",
    "    def _format_responses(conv):\n",
    "        # One \"<Response k>: ...\" line per ASSISTANT turn (the odd indices).\n",
    "        return \"\\n\".join(\n",
    "            f\"<Response {i//2+1}>: \" + conv[i]['content'] for i in range(1, len(conv), 2)\n",
    "        )\n",
    "\n",
    "    # NOTE(review): the prompt strings below are model input and are kept verbatim, including\n",
    "    # the \"coversation\" spelling -- presumably the wording the ranker was trained on; confirm before editing.\n",
    "    instructions = [\"Finish the following coversation in each i-th turn by filling in <Response i> with your response.\"] * len(convAs)\n",
    "    # USER turns are shared by both sides (asserted above), so the source is built from convAs only.\n",
    "    inputs = [\n",
    "        \"\\n\".join(\n",
    "            \"USER: \" + x[i]['content'] + f\"\\nAssistant: <Response {i//2+1}>\"\n",
    "            for i in range(0, len(x), 2)\n",
    "        ) for x in convAs\n",
    "    ]\n",
    "    cand1_texts = [_format_responses(x) for x in convAs]\n",
    "    cand2_texts = [_format_responses(x) for x in convBs]\n",
    "    inputs = [inst + inp for inst, inp in zip(instructions, inputs)]\n",
    "    return tokenize_pair(inputs, cand1_texts, cand2_texts)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-blender",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
